{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6acbea5d",
   "metadata": {},
   "source": [
    "# Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension\n",
    "This notebook walks through a basic example of using the CUTLASS Python interface to declare\n",
    "a grouped GEMM kernel and export it as a PyTorch CUDA extension. Note that GEMM and Conv2d can also be exported as PyTorch CUDA extensions. \n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n",
    "\n",
    "## Background on grouped GEMM\n",
    "Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides)\n",
    "in a single CUDA kernel. It can be thought of as a generalized version of a pointer-array GEMM,\n",
    "without the requirement that the sizes and strides of each GEMM be the same.\n",
    "\n",
    "For example, if one has `p` GEMMs with sizes:\n",
    "```text\n",
    "M_1 x N_1 x K_1\n",
    "M_2 x N_2 x K_2\n",
    "...\n",
    "M_p x N_p x K_p\n",
    "```\n",
    "CUTLASS's grouped GEMM will execute these in a single CUDA kernel.\n",
    "\n",
    "Grouped GEMM is particularly beneficial for saturating the GPU with many small problems that would\n",
    "insufficiently utilize the device in isolation.\n",
    "\n",
    "## Declaring a grouped GEMM via the CUTLASS Python interface\n",
    "A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one\n",
    "simply calls `cutlass.op.GroupedGemm`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdcf21d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cutlass\n",
    "import torch\n",
    "\n",
    "dtype = torch.float16\n",
    "plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "514f40a4",
   "metadata": {},
   "source": [
    "We can then compile and run this operation on a group of GEMMs. We'll first set up some utility functions to initialize GEMMs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2a7371e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "random.seed(2023)\n",
    "\n",
    "# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K\n",
    "def initialize(dtype, M, N, K):\n",
    "    sizes = [(M, K), (K, N), (M, N), (M, N)]\n",
    "    return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]\n",
    "\n",
    "# Utility function to generate `problems` GEMMs of random sizes\n",
    "def generate_problems(problems):\n",
    "    valid_sizes = [128, 256, 512, 1024]\n",
    "    As, Bs, Cs, Ds = [], [], [], []\n",
    "    for _ in range(problems):\n",
    "        M, N, K = [random.choice(valid_sizes) for _ in range(3)]\n",
    "        A, B, C, D = initialize(dtype, M, N, K)\n",
    "        As.append(A)\n",
    "        Bs.append(B)\n",
    "        Cs.append(C)\n",
    "        Ds.append(D)\n",
    "    return As, Bs, Cs, Ds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "590a3bc5",
   "metadata": {},
   "source": [
    "We'll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "776c9233",
   "metadata": {},
   "outputs": [],
   "source": [
    "As, Bs, Cs, Ds, = generate_problems(50)\n",
    "\n",
    "plan.run(As, Bs, Cs, Ds, print_module=True)\n",
    "Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
    "\n",
    "for d, d_torch in zip(Ds, Ds_torch):\n",
    "    assert torch.allclose(d, d_torch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "766e4f03",
   "metadata": {},
   "source": [
    "## Exporting the CUTLASS kernel to a PyTorch CUDA extension\n",
    "The procedure above allows one to quickly experiment with using a CUTLASS kernels However, one might prefer to use the CUTLASS kernel via a [PyTorch CUDA extension](https://pytorch.org/tutorials/advanced/cpp_extension.html). This will avoids adding any runtime overheads associated with the Python portions of the CUTLASS Python interface.\n",
    "\n",
    "The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for a later \"ahead-of-time\" compilation, or be just-in-time compiled and returned to the user.\n",
    "\n",
    "To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a98dee6",
   "metadata": {},
   "outputs": [],
   "source": [
    "op = plan.construct()\n",
    "grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8ca3991",
   "metadata": {},
   "source": [
    "The `cutlass.emit.pytorch` function emits:\n",
    "* `out/grouped_gemm_kernel.cu`: This file contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors\n",
    "* `out/grouped_gemm.cpp`: This file contains a C++ wrapper around the aforementioned CUTLASS kernel\n",
    "* `setup.py`: This file contains the `setuptools` script for building and installing the generated extension\n",
    "\n",
    "The extension can be build from within the `module_output` directory by running:\n",
    "```bash\n",
    "TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py install\n",
    "```\n",
    "Where `TORCH_ARCH_LIST` is set to the compute capability of the device on which the kernel will be run.\n",
    "\n",
    "See the PyTorch [\"Custom C++ and CUDA Extensions\"](https://pytorch.org/tutorials/advanced/cpp_extension.html) tutorial for more details on this.\n",
    "\n",
    "The PyTorch CUDA extension could be built for this module by running:\n",
    "```bash\n",
    "cd out\n",
    "TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py\n",
    "```\n",
    "(assuming that one is building for SM80)\n",
    "\n",
    "One could then use the kernel in a later PyTorch module by running:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import grouped_gemm\n",
    "\n",
    "grouped_gemm.run(As, Bs)\n",
    "```\n",
    "\n",
    "In this case, however, we set `jit=True`, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly.\n",
    "Under the hood, this leverages the [torch.utils.cpp_extension.load](https://pytorch.org/tutorials/advanced/cpp_extension.html) method\n",
    "and returns back the loaded extension.\n",
    "\n",
    "We can then use the extension and compare its results to running the GEMMs via vanilla PyTorch GEMMs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cecb26a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "Ds = grouped_gemm.run(As, Bs)\n",
    "Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
    "for d, d_torch in zip(Ds, Ds_torch):\n",
    "    assert torch.allclose(d, d_torch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50db80e4",
   "metadata": {},
   "source": [
    "Finally, we can profile our grouped GEMM extension:"
   ]
  },
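  {
   "cell_type": "markdown",
   "id": "e41b7a90",
   "metadata": {},
   "source": [
    "The profiling cell that follows uses a warmup-then-measure pattern. As an aside, the same pattern can be sketched as a\n",
    "small reusable helper. The name `average_seconds` and its parameters are hypothetical and not part of CUTLASS or PyTorch,\n",
    "and unlike the cell that follows, this sketch synchronizes once after all timed iterations rather than after each one.\n",
    "```python\n",
    "import time\n",
    "\n",
    "def average_seconds(fn, sync=lambda: None, num_warmup=20, num_profile=100):\n",
    "    # Warmup iterations are run but not timed\n",
    "    for _ in range(num_warmup):\n",
    "        fn()\n",
    "    sync()\n",
    "\n",
    "    # Time all profiled iterations together, waiting for outstanding work at the end\n",
    "    start = time.perf_counter()\n",
    "    for _ in range(num_profile):\n",
    "        fn()\n",
    "    sync()\n",
    "    return (time.perf_counter() - start) / num_profile\n",
    "```\n",
    "With the objects defined in this notebook, one could call\n",
    "`average_seconds(lambda: grouped_gemm.run(As, Bs), sync=torch.cuda.synchronize)`, passing `torch.cuda.synchronize` so that\n",
    "asynchronously launched GPU work is included in the measurement."
   ]
  },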
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b76805d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_warmup = 20\n",
    "num_profile = 100\n",
    "\n",
    "# Warmup iterations\n",
    "for _ in range(num_warmup):\n",
    "    Ds = grouped_gemm.run(As, Bs)\n",
    "    Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
    "    torch.cuda.synchronize()\n",
    "\n",
    "# Timing iterations\n",
    "import time\n",
    "grouped = 0\n",
    "nongrouped = 0\n",
    "for _ in range(num_profile):\n",
    "    start = time.time()\n",
    "    Ds = grouped_gemm.run(As, Bs)\n",
    "    torch.cuda.synchronize()\n",
    "    grouped += time.time() - start\n",
    "\n",
    "    start = time.time()\n",
    "    Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
    "    torch.cuda.synchronize()\n",
    "    nongrouped += time.time() - start\n",
    "\n",
    "print('Grouped:     {:.3f} us'.format(grouped * 1e6/num_profile))\n",
    "print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6/num_profile))\n",
    "print('Speedup: {:.3f}'.format(nongrouped / grouped))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
